import torch
from torch import nn
import torch.nn.functional as F
from ..utils.pe import select_pe_encoder
from ..utils import create_norm, create_activation
import numpy as np

class OmicsEmbedder(nn.Module):
    """Project an expression matrix onto a learnable per-gene embedding table.

    Each column of the input matrix is associated with a gene; the output is
    ``x @ E`` where ``E`` holds one embedding row per input column.

    Args:
        pretrained_gene_list: Ordered gene names; position ``i`` owns row ``i``
            of ``self.emb``.
        num_hid: Embedding width used when ``gene_emb`` is not given.
        gene_emb: Optional initial embedding matrix (tensor or array-like);
            array-likes are converted to ``float32``.
        fix_embedding: If True, the embedding table is frozen
            (``requires_grad=False``).
    """

    def __init__(self, pretrained_gene_list, num_hid, gene_emb=None, fix_embedding=False):
        super().__init__()
        self.pretrained_gene_list = pretrained_gene_list
        # Gene name -> row in ``self.emb`` (a later duplicate name overwrites an earlier one).
        self.gene_index = {gene: i for i, gene in enumerate(pretrained_gene_list)}
        self.num_hid = num_hid

        if gene_emb is None:
            gene_emb = torch.randn([len(pretrained_gene_list), num_hid], dtype=torch.float32)
        elif not isinstance(gene_emb, torch.Tensor):
            # Allow numpy arrays / nested lists as pretrained embeddings.
            gene_emb = torch.as_tensor(gene_emb, dtype=torch.float32)
        self.emb = nn.Parameter(gene_emb, requires_grad=not fix_embedding)

    def forward(self, x_dict, input_gene_list=None):
        """Embed the expression matrix found in ``x_dict``.

        Args:
            x_dict: Dict holding ``'masked_x_seq'`` (preferred when present) or
                ``'x_seq'``. The value is passed to ``torch.sparse.mm`` as the
                left operand — presumably a sparse (n_rows, n_genes) matrix;
                confirm against callers.
            input_gene_list: Gene names of the columns of ``x``. Genes unknown
                to the pretrained list get an all-zero embedding (i.e. they do
                not contribute). For backward compatibility, ``x`` may instead
                already be restricted to the known genes, in their listed order.
                If None, ``x`` must have exactly one column per pretrained gene.

        Returns:
            Dense tensor of shape (x.shape[0], embedding_dim).

        Raises:
            ValueError: If the column count of ``x`` does not match the gene list.
        """
        x = x_dict['masked_x_seq'] if 'masked_x_seq' in x_dict else x_dict['x_seq']
        if input_gene_list is None:
            if x.shape[1] != len(self.pretrained_gene_list):
                raise ValueError('The input gene size is not the same as the pretrained gene list. Please provide the input gene list.')
            gene_idx = torch.arange(x.shape[1], dtype=torch.long, device=x.device)
            feat = F.embedding(gene_idx, self.emb)
        else:
            feat = self._embed_gene_list(x, input_gene_list)
        return torch.sparse.mm(x, feat)

    def _embed_gene_list(self, x, input_gene_list):
        """Return one embedding row per column of ``x``, given its gene names."""
        rows = [self.gene_index.get(gene) for gene in input_gene_list]
        known_cols = [i for i, row in enumerate(rows) if row is not None]
        known_idx = torch.tensor([rows[i] for i in known_cols], dtype=torch.long, device=x.device)
        known_feat = F.embedding(known_idx, self.emb)

        if x.shape[1] == len(known_cols):
            # x has one column per *known* gene (also the case when every gene is known).
            return known_feat
        if x.shape[1] == len(rows):
            # x has one column per listed gene: unknown genes get zero rows so the
            # matmul shapes line up and those columns contribute nothing.
            cols = torch.tensor(known_cols, dtype=torch.long, device=known_feat.device)
            padded = known_feat.new_zeros(len(rows), known_feat.shape[1])
            return padded.index_copy(0, cols, known_feat)
        raise ValueError(
            f'Input has {x.shape[1]} gene columns, but input_gene_list has {len(rows)} genes '
            f'({len(known_cols)} of them in the pretrained gene list).'
        )

class OmicsEmbeddingLayer(nn.Module):
    """Gene-feature embedding with optional positional encoding, dropout and norm.

    Output of ``forward`` is
    ``norm0(dropout(fuse(act(feat_enc(x)), pe_enc(x_dict[pe_key]))))``,
    where ``fuse`` is concatenation (``cat_pe=True``) or addition.

    Args:
        gene_list: Gene names forwarded to ``OmicsEmbedder`` as its pretrained list.
        num_hidden: Width of the layer output (also the width given to ``create_norm``).
        norm: Normalization spec passed to ``create_norm``.
        activation: Activation spec passed to ``create_activation``; applied to
            the gene features only, not to the positional encoding.
        dropout: Dropout probability applied after fusing the positional encoding.
        pe_type: Positional-encoder spec for ``select_pe_encoder``; None disables it.
        cat_pe: If True, concatenate the positional encoding to the gene
            features; otherwise add it.
        gene_emb: Optional initial gene embedding matrix forwarded to ``OmicsEmbedder``.
    """

    def __init__(self, gene_list, num_hidden, norm, activation='gelu', dropout=0.3, pe_type=None, cat_pe=True, gene_emb=None):
        super().__init__()

        self.pe_type = pe_type
        self.cat_pe = cat_pe
        self.act = create_activation(activation)
        self.norm0 = create_norm(norm, num_hidden)
        self.dropout = nn.Dropout(dropout)

        if pe_type is None:
            self.pe_enc = None
            num_emb = num_hidden
        elif cat_pe:
            # Split num_hidden so [features, pe] is exactly num_hidden wide even when
            # num_hidden is odd (previously both halves got num_hidden // 2).
            # Assumes the encoder's output width equals its constructor argument,
            # as the additive branch below already requires.
            num_pe = num_hidden // 2
            self.pe_enc = select_pe_encoder(pe_type)(num_pe)
            num_emb = num_hidden - num_pe
        else:
            # Additive fusion: the PE must match the feature width.
            self.pe_enc = select_pe_encoder(pe_type)(num_hidden)
            num_emb = num_hidden

        # Passing gene_emb=None is identical to omitting it.
        self.feat_enc = OmicsEmbedder(gene_list, num_emb, gene_emb)

    def forward(self, x_dict, input_gene_list=None):
        """Embed ``x_dict`` and return the normalized hidden representation.

        Args:
            x_dict: Input dict consumed by ``OmicsEmbedder``; when a positional
                encoder is configured it must also contain the key named by
                ``self.pe_enc.pe_key``.
            input_gene_list: Optional gene names for the columns of the input
                matrix (see ``OmicsEmbedder.forward``).
        """
        x = self.act(self.feat_enc(x_dict, input_gene_list))

        if self.pe_enc is not None:
            pe = self.pe_enc(x_dict[self.pe_enc.pe_key])
            x = torch.cat([x, pe], 1) if self.cat_pe else x + pe

        return self.norm0(self.dropout(x))

# class OmicsGraphBuilder:
#     def __init__(self, pretrained_gene_list):
#         self.pretrained_gene_list = pretrained_gene_list
#         self.gene_index = dict(zip(pretrained_gene_list, list(range(len(pretrained_gene_list)))))
#
#     def build(self, x, batch_labels=None, input_gene_list=None, training=True):
#         g = dgl.bipartite_from_scipy(x, utype='cell', etype='express', vtype='gene', eweight_name='count')
#         if input_gene_list is not None:
#             g.nodes['gene'].data['id'] = torch.tensor([self.gene_index[o] for o in input_gene_list]).long() + 1
#         else:
#             g.nodes['gene'].data['id'] = torch.arange(x.shape[1]).long() + 1
#         g.nodes['cell'].data['id'] = torch.zeros(x.shape[0]).long()
#
#         if batch_labels is not None:
#             g.nodes['cell'].data['batch'] = batch_labels
#             g.nodes['gene'].data['batch'] = torch.zeros(x.shape[1]).long() + batch_labels.max() + 1
#             g = dgl.to_homogeneous(g, ndata=['id', 'batch'], edata=['count'])
#         else:
#             g = dgl.to_homogeneous(g, ndata=['id'], edata=['count'])
#
#         if training:
#             g.ndata['feat'] = F.pad(torch.from_numpy(np.asarray(x.todense())), (0, 0, 0, x.shape[1]), value=0.)
#         g = dgl.add_self_loop(g)
#         return g


